package algorithm.四边形不等式;

import java.util.Arrays;

/*
* 一条直线上有居民点，邮局只能建在居民点上
给定一个有序正数数组arr，每个值表示 居民点的一维坐标，再给定一个正数 num，表示邮局数量
选择num个居民点建立num个邮局，使所有的居民点到最近邮局的总距离最短，返回最短的总距离
arr=[1,2,3,4,5,1000]，num=2
第一个邮局建立在3位置，第二个邮局建立在1000位置
那么1位置到邮局的距离为2，2位置到邮局距离为1，3位置到邮局的距离为0，4位置到邮局的距离为1，5位置到邮局的距离为2
1000位置到邮局的距离为0
这种方案下的总距离为6，其他任何方案的总距离都不会比该方案的总距离更短，所以返回6
* */
public class PostOfficeProblem {

    public static int min1(int[] arr, int m){

        int[][] tmp = new int[arr.length][arr.length];
        for (int left = 0; left < arr.length; left++) {
            for (int right = left + 1; right < arr.length; right++) {
                int mid = (left + right) / 2;
                tmp[left][right] = tmp[left][right - 1] + arr[right] - arr[mid];
//                System.out.println(tmp[left][right]);
            }
        }
        int[][] dp = new int[arr.length][m];
        int[][] bestSplit = new int[arr.length][m];

        for (int i = 0; i < arr.length; i++) {
            dp[i][0] = tmp[0][i];
        }

        for (int i = 1; i < m; i++) {
            for (int j = arr.length - 1; j >= 0; j--) {
                int down = bestSplit[j][i - 1];
                int up = j + 1 >= arr.length? arr.length - 1 : bestSplit[j + 1][i];
                int min = Integer.MAX_VALUE;
                for (int k = down; k <= up; k++) {
                    int end = Math.min(k + 1, j);
                    int value = dp[k][i - 1] + tmp[end][j];
                    if(min >= value){
                        min = value;
                        bestSplit[j][i] = k;
                    }
                }
                dp[j][i] = min;
            }
        }

        return dp[arr.length - 1][m - 1];

    }

    public static int min2(int[] arr, int num) {
        if (arr == null || num < 1 || arr.length < num) {
            return 0;
        }
        int N = arr.length;
        int[][] w = new int[N + 1][N + 1];
        for (int L = 0; L < N; L++) {
            for (int R = L + 1; R < N; R++) {
                w[L][R] = w[L][R - 1] + arr[R] - arr[(L + R) >> 1];
            }
        }
        int[][] dp = new int[N][num + 1];
        for (int i = 0; i < N; i++) {
            dp[i][1] = w[0][i];
        }
        for (int i = 1; i < N; i++) {
            for (int j = 2; j <= Math.min(i, num); j++) {
                int ans = Integer.MAX_VALUE;
                for (int k = 0; k <= i; k++) {
                    ans = Math.min(ans, dp[k][j - 1] + w[k + 1][i]);
                }
                dp[i][j] = ans;
            }
        }
        return dp[N - 1][num];
    }
    // for test
    public static int[] randomSortedArray(int len, int range) {
        int[] arr = new int[len];
        for (int i = 0; i != len; i++) {
            arr[i] = (int) (Math.random() * range);
        }
        Arrays.sort(arr);
        return arr;
    }

    // for test
    public static void printArray(int[] arr) {
        for (int i = 0; i != arr.length; i++) {
            System.out.print(arr[i] + " ");
        }
        System.out.println();
    }

    // for test
    public static void main(String[] args) {
        int N = 30;
        int maxValue = 100;
        int testTime = 10000;
        System.out.println("测试开始");
        for (int i = 0; i < testTime; i++) {
            int len = (int) (Math.random() * N) + 1;
            int[] arr = randomSortedArray(len, maxValue);
            int num = (int) (Math.random() * N) + 1;
            int ans1 = min1(arr, num);
            int ans2 = min2(arr, num);
            if (ans1 != ans2) {
                printArray(arr);
                System.out.println(num);
                System.out.println(ans1);
                System.out.println(ans2);
                System.out.println("Oops!");
            }
        }
        System.out.println("测试结束");

    }

}
